import json
import os
import openai
from collections import deque
import logging

class Graph2Text:
    """Turn multi-hop knowledge-graph facts into LLM-generated questions.

    The graph is loaded from a JSON file and is expected to map entity names
    to lists of ``(relation, object)`` pairs. Entity names are lowercased
    before lookup, so the JSON keys are presumably stored in lowercase.
    """

    def __init__(self, graph_path, api_key, model="gpt-4o", max_hops=2):
        """Load the graph and configure the OpenAI client.

        Args:
            graph_path: Path to the JSON adjacency file.
            api_key: OpenAI API key; stored on the ``openai`` module.
            model: Chat model name used by :meth:`ask_gpt`.
            max_hops: Deepest node depth that :meth:`hop_graph` still expands.
        """
        # Explicit encoding: JSON is UTF-8 by specification, while the
        # platform default may be something else.
        with open(graph_path, "r", encoding="utf-8") as f:
            self.graph = json.load(f)
        openai.api_key = api_key
        self.model = model
        self.max_hops = max_hops
        logging.basicConfig(level=logging.INFO)

    def get_neighbors(self, node):
        """Return the ``(relation, object)`` pairs stored for ``node``.

        Lookup is case-insensitive on the node side. Non-string nodes (for
        example numeric literals appearing as objects) have no neighbors.
        """
        if not isinstance(node, str):
            return []
        return self.graph.get(node.lower(), [])

    def hop_graph(self, node) -> list:
        """Collect facts reachable from ``node`` with a breadth-first walk.

        Every node whose depth is at most ``self.max_hops`` is expanded, and
        one fact string ``'"subject" relation "object"'`` is emitted per
        outgoing edge, in BFS order.

        Returns:
            The list of fact strings (empty when the node is unknown).
        """
        visited = set()
        queue = deque([(node, 0)])
        facts = []

        while queue:
            current_node, depth = queue.popleft()
            # Literal objects (numbers, lists, None, ...) are not entities;
            # they cannot be looked up and may not even be hashable.
            if not isinstance(current_node, str) or depth > self.max_hops:
                continue
            # Track visits under the same lowercase key used for lookup so
            # "Paris" and "paris" are not expanded twice.
            key = current_node.lower()
            if key in visited:
                continue
            visited.add(key)
            for relation, obj in self.get_neighbors(current_node):
                facts.append(f'"{current_node}" {relation} "{obj}"')
                queue.append((obj, depth + 1))

        return facts

    def build_diverse_prompt(self, node, facts) -> str:
        """Build the prompt asking for five question types about ``node``.

        Args:
            node: Entity name interpolated into the prompt.
            facts: Fact strings, inserted one per line.

        Returns:
            The prompt text.
        """
        fact_block = "\n".join(facts)

        prompt = f"""
        The following are facts from a knowledge graph about the entity "{node}":
        {fact_block}

        Generate 5 different types of questions about "{node}":
        1. Descriptive: Ask about general properties or definitions of {node}.
        2. Relational: Ask about its relationships to other entities.
        3. Causal or Explanation-seeking: Ask about causes, effects, or explanations.
        4. Historical or Temporal: Ask about historical context or timeline.
        5. Hypothetical or Predictive: Ask about hypothetical or future scenarios.

        Output format:
        Q1: ...
        Q2: ...
        Q3: ...
        Q4: ...
        Q5: ...
        """
        return prompt

    def ask_gpt(self, prompt) -> str:
        """Send ``prompt`` as a single user message and return the reply text.

        Returns:
            The stripped message content, or ``""`` if the API returned none.
        """
        request = {
            "model": self.model,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": 0.7,
            "max_tokens": 700,
        }
        if hasattr(openai, "chat"):
            # openai>=1.0 removed ChatCompletion; the module-level client
            # still honours ``openai.api_key`` set in __init__.
            response = openai.chat.completions.create(**request)
            content = response.choices[0].message.content
        else:
            # Legacy openai<1.0 interface.
            response = openai.ChatCompletion.create(**request)
            content = response['choices'][0]['message']['content']
        return (content or "").strip()

    def generate_diverse_qa(self, node):
        """Generate questions for one entity.

        Returns:
            A dict with keys ``entity``, ``facts`` and ``questions``, or
            ``None`` when the graph holds no facts for ``node``.
        """
        facts = self.hop_graph(node)
        if not facts:
            logging.warning("No facts found for %s. Skipping.", node)
            return None
        prompt = self.build_diverse_prompt(node, facts)
        response = self.ask_gpt(prompt)
        return {"entity": node, "facts": facts, "questions": response}

    def generate_qa_dataset(self, nodes, save_path="qa_dataset/qa_diverse.jsonl"):
        """Write one JSON line per entity in ``nodes`` that has facts.

        Args:
            nodes: Iterable of entity names.
            save_path: Output JSONL path; its parent directory is created
                when the path contains one. The file is overwritten.
        """
        # dirname is "" for a bare filename, and os.makedirs("") raises.
        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        with open(save_path, "w", encoding="utf-8") as f_out:
            for node in nodes:
                qa = self.generate_diverse_qa(node)
                if qa:
                    f_out.write(json.dumps(qa) + "\n")
                    logging.info("Generated Q&A for: %s", node)

        logging.info("Saved dataset to %s", save_path)
